查看原文
其他

深度学习之 图解LSTM

Shi Yan 大邓和他的Python 2019-04-26

原文标题: Understanding LSTM and its diagrams
原文链接:https://medium.com/mlreview/understanding-lstm-and-its-diagrams-37e2f46f1714
作者:Shi Yan
译者:大邓


尽管我们不知道大脑运行机制,但我们知道一定有存在 逻辑单元记忆单元。我们依据已有的经验,根据对具体情况的分析判断来做出决策。类比电脑, 人脑应该有逻辑单元、CPU和GPUS,我们也应该有内存(memory)。

但当我们看待一个神经网络时却经常将其错看成一个黑箱。我们似乎仅仅知道给这个黑箱输入一些数据,就会得到一些输出数据。而决策的制定通常基于当前输入的数据。

我觉得大家对神经网络没有记忆这种观点是错误的。毕竟,从数据中学到的权重参数某种程度上就是记忆,只不过这些记忆是静态不变的。有时候我们想将学习到的知识应用到未知或者未来的场景中去。例如,股票市场的预测、做完形填空题等。

最粗暴的实现方法是对每一个时刻的数据训练一个神经网络NNt,然后将所有时刻的神经网络NNt组成更大的神经网络BigNN。每个神经网络仅仅处理某个时刻的数据。我们可以将所有的数据提供给BigNN,而不是单独的提供给NNt。

我们想到的组织NNt成BigNN,更专业的叫法叫做循环神经网络(Recurrent neural network)。但在实际使用过程中,遇到两个问题:梯度消失和梯度爆炸,这让我们的构想实现起来问题多多,困难重重。

基于这个背景,LSTM(Long short term memeory)应运而生。下图是LSTM的最基本的组件。

一、LSTM的输入与输出

乍一看有点懵逼。让我们首先忽略内部细节,仅仅看这个神经网络的 输入什么输出什么。当前神经网络整体(当前LSTM)来看,需要三种输入数据:

  • X_t当前时间输进LSTM网络的数据

  • h_t-1 是前一个LSTM网络输出的结果

  • C_t-1 是之前的LSTM网络调整后的记忆,我认为 C_t-1 是最重要的输入

至于当前神经网络(当前LSTM)的输出有两个:

  • h_t 当前LSTM网络输出的结果

  • C_t 是当前LSTM网络调整后的记忆

因此,刚刚我们解读的LSTM会依据 当前输入的数据、前一期的输出和前一期的记忆。并且生成了 新的输出、修改了记忆

二、记忆的调整

而在LSTM网络内部中的 记忆C_t 的改变就像是管道里的水。改变记忆水管中水流量的方法有两种,一种是 遗忘,另一种是 更新信息,这里我简称为更新

遗忘

在LSTM模型图中,图中最上面的有色区域是管道。输入的是旧的记忆(vector)。X表示过去记忆在当前LSTM中的保留率。毕竟如果所有的记忆都要保留不论机器还是人脑都是受不了的,所以网络总会有一个遗忘率。所以如果你想对 旧的记忆C_t-1  乘以 一个越接近0 的 vector, 那么就意味着 该LSTM要遗忘绝大部分记忆。如果想保留所有旧的记忆,那么久让 旧的记忆C_t-1 乘以 等于1 的 vector。

更新

第二种操作是 更新信息,类似于水管中的 +

三、遗忘网络

现在我们看记忆水管的 X ,上图中的彩色部分是LSTM的遗忘网络,该网络计算的结果是 遗忘率, 该部分是被一层简单的神经网络所控制。 这个计算遗忘率的神经网络的输入的内容包括:

  • X_t: 当前LSTM的输入

  • C_t-1: 上一期LSTM调整后的记忆

  • h_t-1: 上一期LSTM的输出结果

  • bias: 偏差

这个神经网络使用 sigmoid 函数做为激活函数,其计算得到的向量就是遗忘率(vector)。

四、更新网络

现在我们在看看 更新网络。这也是一个简单的神经网络,且需要输入的数据与 遗忘网络 相同。该网络的计算结果是用来让新信息影响旧知识,更确切的说叫做 更新知识

五、新记忆

新记忆自身是通过另外一个简单的神经网络生成的,该网络使用tanh作为激活函数。

六、LSTM的输出h_t

最后,我们需要生成当前LSTM网络的输出 h_t。 h_t是由一下几个因素计算得来,分别:

  • C_t: 当前LSTM调整后的记忆

  • h_t-1: 前一期LSTM的输出结果

  • X_t: 当前时间输进LSTM网络的数据

  • bias: 偏差

LSTM的输出h_t 控制着多少新的信息应该传递给下一期的LSTM

往期文章

《用Python做文本分析》视频课程 

10分钟理解深度学习中的~卷积~

100G Python学习资料(免费下载) 

100G 文本分析语料资源(免费下载)    

typing库:让你的代码阅读者再也不用猜猜猜  

Seaborn官方教程中文教程(一)

数据清洗 常用正则表达式大全

大邓强力推荐-jupyter notebook使用小技巧  

PySimpleGUI: 开发自己第一个软件

深度特征合成:自动生成机器学习中的特征

Python 3.7中dataclass的终极指南(一) 

Python 3.7中dataclass的终极指南(二) 

15个最好的数据科学领域Python库    

使用Pandas更好的做数据科学

[计算消费者的偏好]推荐系统与协同过滤、奇异值分解

机器学习: 识别图片中的数字

应用PCA降维加速模型训练

如何从文本中提取特征信息?

使用sklearn做自然语言处理-1 

使用sklearn做自然语言处理-2

机器学习|八大步骤解决90%的NLP问题    

Python圈中的符号计算库-Sympy

Python中处理日期时间库的使用方法 

【视频讲解】Scrapy递归抓取简书用户信息

美团商家信息采集神器 

用chardect库解决网页乱码问题

    您可能也对以下帖子感兴趣

    文章有问题?点此查看未经处理的缓存